{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Demo - Convert to Bayesian Neural Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "\n",
    "import torchbnn as bnn\n",
    "from torchhk import transform_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Define Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CNN(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(CNN, self).__init__()\n",
    "        \n",
    "        self.conv_layer = nn.Sequential(\n",
    "            nn.Conv2d(1,3,3),\n",
    "            nn.ReLU(),\n",
    "            nn.MaxPool2d(2,2)\n",
    "        )\n",
    "        \n",
    "        self.fc_layer = nn.Sequential(\n",
    "            nn.Linear(3*2*2,3*2),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(3*2,2)\n",
    "        )       \n",
    "        \n",
    "    def forward(self,x):\n",
    "        out = self.conv_layer(x)\n",
    "        out = out.view(-1,3*2*2)\n",
    "        out = self.fc_layer(out)\n",
    "\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = CNN()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "CNN(\n",
       "  (conv_layer): Sequential(\n",
       "    (0): Conv2d(1, 3, kernel_size=(3, 3), stride=(1, 1))\n",
       "    (1): ReLU()\n",
       "    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "  )\n",
       "  (fc_layer): Sequential(\n",
       "    (0): Linear(in_features=12, out_features=6, bias=True)\n",
       "    (1): ReLU()\n",
       "    (2): Linear(in_features=6, out_features=2, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Convert Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.1. Nonbayes to Bayes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\slcf\\Anaconda3\\lib\\site-packages\\torchhk\\transform.py:31: Warning: \n",
      " * Caution : The Input Model is CHANGED because inplace=True.\n",
      "  warnings.warn(\"\\n * Caution : The Input Model is CHANGED because inplace=True.\", Warning)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "CNN(\n",
       "  (conv_layer): Sequential(\n",
       "    (0): BayesConv2d(0, 0.1, 1, 3, kernel_size=(3, 3), stride=(1, 1))\n",
       "    (1): ReLU()\n",
       "    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "  )\n",
       "  (fc_layer): Sequential(\n",
       "    (0): Linear(in_features=12, out_features=6, bias=True)\n",
       "    (1): ReLU()\n",
       "    (2): Linear(in_features=6, out_features=2, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Convert Conv2d -> BayesConv2d\n",
    "transform_model(model, nn.Conv2d, bnn.BayesConv2d, \n",
    "                args={\"prior_mu\":0, \"prior_sigma\":0.1, \"in_channels\" : \".in_channels\",\n",
    "                      \"out_channels\" : \".out_channels\", \"kernel_size\" : \".kernel_size\",\n",
    "                      \"stride\" : \".stride\", \"padding\" : \".padding\", \"bias\":\".bias\"\n",
    "                     }, \n",
    "                attrs={\"weight_mu\" : \".weight\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\slcf\\Anaconda3\\lib\\site-packages\\torchhk\\transform.py:31: Warning: \n",
      " * Caution : The Input Model is CHANGED because inplace=True.\n",
      "  warnings.warn(\"\\n * Caution : The Input Model is CHANGED because inplace=True.\", Warning)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "CNN(\n",
       "  (conv_layer): Sequential(\n",
       "    (0): BayesConv2d(0, 0.1, 1, 3, kernel_size=(3, 3), stride=(1, 1))\n",
       "    (1): ReLU()\n",
       "    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "  )\n",
       "  (fc_layer): Sequential(\n",
       "    (0): BayesLinear(prior_mu=0, prior_sigma=0.1, in_features=12, out_features=6, bias=True)\n",
       "    (1): ReLU()\n",
       "    (2): BayesLinear(prior_mu=0, prior_sigma=0.1, in_features=6, out_features=2, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Convert Linear -> BayesLinear\n",
    "transform_model(model, nn.Linear, bnn.BayesLinear, \n",
    "            args={\"prior_mu\":0, \"prior_sigma\":0.1, \"in_features\" : \".in_features\",\n",
    "                  \"out_features\" : \".out_features\", \"bias\":\".bias\"\n",
    "                 }, \n",
    "            attrs={\"weight_mu\" : \".weight\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "CNN(\n",
       "  (conv_layer): Sequential(\n",
       "    (0): BayesConv2d(0, 0.1, 1, 3, kernel_size=(3, 3), stride=(1, 1))\n",
       "    (1): ReLU()\n",
       "    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "  )\n",
       "  (fc_layer): Sequential(\n",
       "    (0): BayesLinear(prior_mu=0, prior_sigma=0.1, in_features=12, out_features=6, bias=True)\n",
       "    (1): ReLU()\n",
       "    (2): BayesLinear(prior_mu=0, prior_sigma=0.1, in_features=6, out_features=2, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.2. Bayes to Nonbayes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\slcf\\Anaconda3\\lib\\site-packages\\torchhk\\transform.py:31: Warning: \n",
      " * Caution : The Input Model is CHANGED because inplace=True.\n",
      "  warnings.warn(\"\\n * Caution : The Input Model is CHANGED because inplace=True.\", Warning)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "CNN(\n",
       "  (conv_layer): Sequential(\n",
       "    (0): Conv2d(1, 3, kernel_size=(3, 3), stride=(1, 1))\n",
       "    (1): ReLU()\n",
       "    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "  )\n",
       "  (fc_layer): Sequential(\n",
       "    (0): BayesLinear(prior_mu=0, prior_sigma=0.1, in_features=12, out_features=6, bias=True)\n",
       "    (1): ReLU()\n",
       "    (2): BayesLinear(prior_mu=0, prior_sigma=0.1, in_features=6, out_features=2, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Convert BayesConv2d -> Conv2d\n",
    "transform_model(model, bnn.BayesConv2d, nn.Conv2d,\n",
    "                args={\"in_channels\" : \".in_channels\", \"out_channels\" : \".out_channels\",\n",
    "                      \"kernel_size\" : \".kernel_size\",\n",
    "                      \"padding\" : \".padding\", \"bias\":\".bias\"\n",
    "                     }, \n",
    "                attrs={\"weight\" : \".weight_mu\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\slcf\\Anaconda3\\lib\\site-packages\\torchhk\\transform.py:31: Warning: \n",
      " * Caution : The Input Model is CHANGED because inplace=True.\n",
      "  warnings.warn(\"\\n * Caution : The Input Model is CHANGED because inplace=True.\", Warning)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "CNN(\n",
       "  (conv_layer): Sequential(\n",
       "    (0): Conv2d(1, 3, kernel_size=(3, 3), stride=(1, 1))\n",
       "    (1): ReLU()\n",
       "    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "  )\n",
       "  (fc_layer): Sequential(\n",
       "    (0): Linear(in_features=12, out_features=6, bias=True)\n",
       "    (1): ReLU()\n",
       "    (2): Linear(in_features=6, out_features=2, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Convert BayesLinear -> Linear\n",
    "transform_model(model, bnn.BayesLinear, nn.Linear, \n",
    "            args={\"in_features\" : \".in_features\", \"out_features\" : \".out_features\",\n",
    "                  \"bias\":\".bias\"\n",
    "                 }, \n",
    "            attrs={\"weight\" : \".weight_mu\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "CNN(\n",
       "  (conv_layer): Sequential(\n",
       "    (0): Conv2d(1, 3, kernel_size=(3, 3), stride=(1, 1))\n",
       "    (1): ReLU()\n",
       "    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "  )\n",
       "  (fc_layer): Sequential(\n",
       "    (0): Linear(in_features=12, out_features=6, bias=True)\n",
       "    (1): ReLU()\n",
       "    (2): Linear(in_features=6, out_features=2, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
